package DnsServer

import (
	"encoding/binary"
	"encoding/json"
	"errors"
	"fmt"
	"gitee.com/pdudo/SampleDNS2/db"
	_ "gitee.com/pdudo/SampleDNS2/log"
	"gitee.com/pdudo/SampleDNS2/web"
	"gitee.com/pdudo/SampleDNSTool"
	"log"
	"net"
	"strings"
	"sync"
	"time"
)


// DnsUser carries the per-request client details used by the TCP and UDP
// handlers.
type DnsUser struct {
	// ID is the DNS transaction ID; UDPStart copies it from the request header.
	ID uint16
	// IP is the client address, used as a prefix in log lines.
	IP string
}
var (
	// cacheLock is held by saveCache while it writes a cache entry.
	cacheLock sync.Mutex

	// processLock guards ProcessID and SumCount inside AddGoprocess.
	processLock sync.RWMutex

	// ProcessID accumulates the deltas passed to AddGoprocess.
	ProcessID uint64

	// SumCount is incremented once for every AddGoprocess call.
	SumCount uint64
)

// AddGoprocess adds id to ProcessID and increments SumCount by one, holding
// processLock for the whole update.
//
// The handlers in this file call it with 1 on entry and with -1 (deferred) on
// exit. A negative id still works on the unsigned counter because converting
// it to uint64 wraps around, which turns the addition into a modular
// subtraction.
//
// NOTE(review): SumCount also grows on the -1 call, so it advances twice per
// handled request — confirm that this is intended.
func AddGoprocess(id int) {
	processLock.Lock()
	defer processLock.Unlock()

	SumCount++
	ProcessID += uint64(id)
}

// TCPStart serves a single DNS query received over a TCP connection.
//
// The request is framed as a two-byte big-endian length followed by the DNS
// message. The question is first resolved locally through dnsQuery; when no
// local record matches, the raw request (length prefix and message) is handed
// to proxyDNSTCPServer. Otherwise a response made of header, question and
// answers is written back using the same length framing.
//
// conn is closed on read failures and on messages too short to hold a DNS
// header. On every other path it is left open for the caller or the proxy.
func (user DnsUser) TCPStart(conn net.Conn) {
	AddGoprocess(1)
	defer AddGoprocess(-1)

	const (
		// dnsHeaderLen is the fixed size of a DNS message header.
		dnsHeaderLen = 12
		// maxTCPMessage is the largest size the two-byte length prefix can carry.
		maxTCPMessage = 0xFFFF
	)

	log.Println(user.IP, " connect private dns TCP servers")

	// readFull keeps reading until dst is completely filled. Every Read
	// continues at the current offset, so a message that arrives in several
	// TCP segments is assembled in order instead of overwriting its start.
	readFull := func(dst []byte) error {
		for off := 0; off < len(dst); {
			n, err := conn.Read(dst[off:])
			off += n
			if err != nil && off < len(dst) {
				return err
			}
		}
		return nil
	}

	HeadBuf := make([]byte, 2)
	if err := readFull(HeadBuf); err != nil {
		log.Println(user.IP, "read Head Buf error ", err)
		conn.Close()
		return
	}

	msgLen := binary.BigEndian.Uint16(HeadBuf)
	if msgLen < dnsHeaderLen {
		// Do not hand the parser a buffer that cannot even hold a header.
		log.Println(user.IP, "dns message shorter than header, length: ", msgLen)
		conn.Close()
		return
	}

	dataBuf := make([]byte, msgLen)
	if err := readFull(dataBuf); err != nil {
		log.Println(user.IP, "read Data buf error ", err)
		conn.Close()
		return
	}

	var dnsInfo SampleDNSTool.DNSInfo

	dnsInfo.GetHeader(dataBuf)
	// Mirror UDPStart: the logged query ID is the one from the request header.
	user.ID = dnsInfo.Header.ID
	log.Println(user.IP, " QueryId: ", user.ID)

	dnsInfo.GetQuestion(dataBuf)
	log.Println(user.IP, " QueryType: ", dnsInfo.QueryInfo.QTYPE, " QueryName: ", dnsInfo.QueryInfo.QNAMEString)

	answers, count, err := user.dnsQuery(dnsInfo)
	if err != nil {
		log.Println("dns query error ", err)
	}
	if count == 0 {
		// Nothing stored locally: forward the untouched request upstream.
		user.proxyDNSTCPServer(conn, HeadBuf, dataBuf)
		return
	}

	dnsInfo.Header.HeaderStatus.QR = 1
	dnsInfo.Header.ANCOUNT = uint16(count)
	dnsInfo.Header.ARCOUNT = 0
	dnsInfo.Header.HeaderStatus.AA = 1
	dnsInfo.Header.HeaderStatus.RD = 1
	headerBuf := dnsInfo.GenerateHeaders()

	question := dnsInfo.GenerateQuestion()

	msgSize := len(headerBuf) + len(question) + len(answers)
	if msgSize > maxTCPMessage {
		// The length prefix cannot describe this message; sending it with a
		// wrapped length would corrupt the stream.
		log.Println(user.IP, "dns response too large for TCP framing, length: ", msgSize)
		return
	}

	// Length prefix and message go out in one write so the client never sees
	// the prefix without the message that follows it.
	frame := make([]byte, 2, 2+msgSize)
	binary.BigEndian.PutUint16(frame, uint16(msgSize))
	frame = append(frame, headerBuf...)
	frame = append(frame, question...)
	frame = append(frame, answers...)

	if _, err := conn.Write(frame); err != nil {
		log.Println(user.IP, "write response error ", err)
	}
}

// UDPStart serves a single DNS query received as a UDP datagram.
//
// conn is the listening socket, connUDP the client address the reply is sent
// to, and buf the raw request. The receiver's IP and ID fields are filled from
// the client address and the request header. The question is resolved locally
// through dnsQuery; when no local record matches, the request is handed to
// proxyDNSUDPServer. Otherwise exactly one response datagram is sent: if the
// full message would exceed 512 bytes, the TC flag is set and the datagram is
// cut to 512 bytes.
func (user *DnsUser) UDPStart(conn *net.UDPConn, connUDP *net.UDPAddr, buf []byte) {
	AddGoprocess(1)
	defer AddGoprocess(-1)

	const (
		// dnsHeaderLen is the fixed size of a DNS message header.
		dnsHeaderLen = 12
		// maxUDPMessage is the classic size limit of a DNS message over UDP.
		maxUDPMessage = 512
	)

	user.IP = connUDP.IP.String()

	log.Println(user.IP, " connect private dns udp servers")

	if len(buf) < dnsHeaderLen {
		// Do not hand the parser a buffer that cannot even hold a header.
		log.Println(user.IP, "dns message shorter than header, length: ", len(buf))
		return
	}

	var dnsInfo SampleDNSTool.DNSInfo
	dnsInfo.GetHeader(buf)

	user.ID = dnsInfo.Header.ID

	log.Println(user.IP, " QueryId: ", user.ID)

	dnsInfo.GetQuestion(buf)
	log.Println(user.IP, " QueryType: ", dnsInfo.QueryInfo.QTYPE, " QueryName: ", dnsInfo.QueryInfo.QNAMEString)

	newBuf, count, err := user.dnsQuery(dnsInfo)
	if err != nil {
		log.Println("dns query error ", err)
	}
	if count == 0 {
		// Nothing stored locally: forward the request upstream.
		user.proxyDNSUDPServer(conn, connUDP, buf, dnsInfo)
		return
	}

	question := dnsInfo.GenerateQuestion()

	// From here on newBuf holds question + answers.
	newBuf = append(question, newBuf...)

	dnsInfo.Header.HeaderStatus.QR = 1
	dnsInfo.Header.ANCOUNT = uint16(count)
	dnsInfo.Header.ARCOUNT = 0
	dnsInfo.Header.HeaderStatus.AA = 1
	dnsInfo.Header.HeaderStatus.RD = 1

	// newBuf already contains the question, so only the header is added to
	// obtain the final message size.
	truncated := dnsHeaderLen+len(newBuf) > maxUDPMessage
	if truncated {
		dnsInfo.Header.HeaderStatus.TC = 1
	}
	headerBuf := dnsInfo.GenerateHeaders()

	newBuf = append(headerBuf, newBuf...)

	// Cut only when this handler decided to truncate; a TC bit inherited from
	// the request must not trigger slicing of a short message.
	if truncated && len(newBuf) > maxUDPMessage {
		newBuf = newBuf[:maxUDPMessage]
	}

	if _, err := conn.WriteToUDP(newBuf, connUDP); err != nil {
		log.Println(user.IP, "write udp response error ", err)
	}
}

// dbQuery looks up the stored value for a DNS name and record type in Redis.
//
// A trailing dot is appended to name when missing, then name is split at its
// first dot. The left label selects the hash field ("<label>_A" or
// "<label>_CNAME"). The remainder selects the hash key
// ("<RegisterKeys>_<remainder>", or "<RegisterKeys>_root" when nothing follows
// the label). The stored value is JSON that decodes into web.DnsSaveInfo, and
// its Ip field is returned as-is (dnsQuery treats it as a comma-separated
// list).
//
// An error is returned for record types other than A and CNAME, for Redis
// failures, for missing or empty entries, and for undecodable JSON.
func dbQuery(name string, types uint16) (string, error) {
	if !strings.HasSuffix(name, ".") {
		name += "."
	}

	// Reject unsupported types up front rather than querying Redis with an
	// empty field name.
	var suffix string
	switch types {
	case SampleDNSTool.Type_A:
		suffix = "A"
	case SampleDNSTool.Type_CNAME:
		suffix = "CNAME"
	default:
		return "", fmt.Errorf("unsupported record type %d", types)
	}

	// records[0] is the host label, records[1] everything after the first dot.
	// The trailing dot added above guarantees that both elements exist.
	records := strings.SplitN(name, ".", 2)
	recordType := fmt.Sprintf("%s_%s", records[0], suffix)

	zone := records[1]
	if zone == "" {
		zone = "root"
	}
	servers := fmt.Sprintf("%s_%s", db.DnsConf.ConfWeb.RegisterKeys, zone)

	cmd, err := db.HGet(servers, recordType)
	if err != nil {
		// Logged here because the callers in this file only test err != nil
		// and never print the cause.
		log.Println("redis error ", err)
		return "", fmt.Errorf("redis hget %s %s: %w", servers, recordType, err)
	}
	if cmd == "" {
		return "", errors.New("not found")
	}

	var info web.DnsSaveInfo
	if err := json.Unmarshal([]byte(cmd), &info); err != nil {
		log.Println("json unmarshal error ", err)
		return "", fmt.Errorf("json unmarshal error: %w", err)
	}

	return info.Ip, nil
}


// dnsQuery resolves the question in dnsInfo against the local records and
// returns the encoded answer section together with the number of answers.
//
// A and CNAME questions are supported. When an A lookup misses, the CNAME
// record for the same name is tried and its target is then resolved as an A
// record, so the answer section holds the CNAME followed by the A records. A
// count of 0 means the name could not be fully resolved locally; the callers
// in this file then fall back to the upstream proxy. The returned error is
// always nil.
func (u *DnsUser) dnsQuery(dnsInfo SampleDNSTool.DNSInfo) ([]byte, int, error) {
	return u.dnsQueryDepth(dnsInfo, 0)
}

// dnsQueryDepth implements dnsQuery. depth counts the CNAME hops already
// followed, so that a cyclic CNAME chain ends instead of recursing without
// bound.
func (u *DnsUser) dnsQueryDepth(dnsInfo SampleDNSTool.DNSInfo, depth int) ([]byte, int, error) {
	// maxCNAMEDepth bounds how many CNAME hops are followed for one question.
	const maxCNAMEDepth = 8

	var buf []byte
	count := 0

	val, err := dbQuery(dnsInfo.QueryInfo.QNAMEString, dnsInfo.QueryInfo.QTYPE)
	if err != nil {
		log.Println("query error", dnsInfo.QueryInfo.QNAMEString, dnsInfo.QueryInfo.QTYPE)

		// Only a failed A lookup falls back to a CNAME lookup.
		if SampleDNSTool.Type_A != dnsInfo.QueryInfo.QTYPE {
			return buf, count, nil
		}
		if depth >= maxCNAMEDepth {
			log.Println("cname chain too deep, giving up on", dnsInfo.QueryInfo.QNAMEString)
			return buf, count, nil
		}

		// Look up the CNAME record for the same name.
		var next SampleDNSTool.DNSInfo
		next.QueryInfo.QNAMEString = dnsInfo.QueryInfo.QNAMEString
		next.QueryInfo.QTYPE = SampleDNSTool.Type_CNAME
		next.QueryInfo.QCLASS = 1

		cnameBuf, c, _ := u.dnsQueryDepth(next, depth)
		if 0 == c {
			return buf, c, nil
		}
		buf = append(buf, cnameBuf...)
		count += c

		// Only the first CNAME record is followed. The second return value is
		// presumably the offset just past the owner name; adding 10 then skips
		// TYPE(2) + CLASS(2) + TTL(4) + RDLENGTH(2) to land on RDATA, which
		// holds the target name — TODO confirm against GetOffsetNames.
		_, id := SampleDNSTool.GetOffsetNames(cnameBuf, 0)
		target, _ := SampleDNSTool.GetOffsetNames(cnameBuf, id+10)

		next.QueryInfo.QNAMEString = target
		next.QueryInfo.QTYPE = SampleDNSTool.Type_A
		next.QueryInfo.QCLASS = 1

		aBuf, c, _ := u.dnsQueryDepth(next, depth+1)
		if 0 == c {
			// The target has no local address: report zero answers so the
			// caller proxies the whole question upstream.
			return buf, c, nil
		}
		buf = append(buf, aBuf...)
		count += c

		return buf, count, nil
	}

	// The stored value is a comma-separated list; each entry becomes one
	// answer record.
	for _, v := range strings.Split(val, ",") {
		switch dnsInfo.QueryInfo.QTYPE {
		case SampleDNSTool.Type_A:
			// An IPv4 address is always four bytes of RDATA.
			dnsInfo.AnswerInfo.RDLENGTH = 4
			dnsInfo.AnswerInfo.RDATA = SampleDNSTool.GererateDNSNamesIP(v)
		case SampleDNSTool.Type_CNAME:
			dnsInfo.AnswerInfo.RDLENGTH = SampleDNSTool.GererateDNSNamesLen(v)
			dnsInfo.AnswerInfo.RDATA = SampleDNSTool.GererateDNSNames(v)
		default:
			continue
		}

		dnsInfo.AnswerInfo.NAME = dnsInfo.QueryInfo.QNAMEString
		dnsInfo.AnswerInfo.TYPE = dnsInfo.QueryInfo.QTYPE
		dnsInfo.AnswerInfo.CLASS = dnsInfo.QueryInfo.QCLASS
		dnsInfo.AnswerInfo.TTL = 0

		buf = append(buf, dnsInfo.GenerateAnswer()...)
		count++
	}

	return buf, count, nil
}


// saveCache writes buf under key through db.Set while holding cacheLock,
// passing ten minutes as the duration argument (presumably the entry's
// expiry — confirm against db.Set). It reports whether the write succeeded;
// a failure is logged together with the client IP.
func (user *DnsUser) saveCache(key string, buf []byte) bool {
	cacheLock.Lock()
	defer cacheLock.Unlock()

	var writeErr error
	if _, writeErr = db.Set(key, buf, 10*time.Minute); writeErr != nil {
		log.Println(user.IP, "save Cache error , key: ", key, "error ", writeErr)
		return false
	}
	return true
}

// getCache fetches the value stored under key through db.Get.
//
// A lookup error is logged and returned unchanged. An empty value is reported
// as the error "nil". Otherwise the value is returned as a byte slice.
func (user *DnsUser) getCache(key string) ([]byte, error) {
	cached, err := db.Get(key)
	switch {
	case err != nil:
		log.Println(user.IP, "get Cache error , key: ", key, " error: ", err)
		return nil, err
	case cached == "":
		return nil, errors.New("nil")
	}
	return []byte(cached), nil
}